# ------------------------------------------------------------------------------
# Measures performance of limited bandwidth models
# Author: Cassidy Shubatt <cshubatt@gmail.com>
# To run: bsub -q big -R "rusage[mem=10000]" bash 05_measure_performance.sh
# ------------------------------------------------------------------------------

# Seeding ----------------------------------------------------------------------
# Fixed seed so the random draws made later in this script (the bootstrap
# resampling uses sample()) are reproducible from run to run.
set.seed(seed = 1)

# Libraries --------------------------------------------------------------------
message("Loading libraries...")
library(here)
library(yaml)
library(data.table)
library(tidyverse)
library(Metrics)
library(Matrix)
library(glue)

# Working directory for this pipeline stage: the score files are read from it
# and the performance tables are written back to it.
temp <- here("code", "06_physician_boundedness", "01_behavioral_lasso", "temp")

# Project utilities loaded as a module object. `u` is not referenced in the
# visible part of this script.
u <- modules::use(here("lib", "util.R"))

# Helpers ----------------------------------------------------------------------
# Scores one set of predictions against three outcomes on a chosen set of rows.
#
# Args:
#   score:     predicted values, one per observation.
#   stent:     outcome vector aligned element-wise with `score`.
#   test:      outcome vector aligned with `score`; it is also the flag used to
#              split observations into the two sub-populations below.
#   mace:      outcome vector aligned element-wise with `score`.
#   which_obs: index applied to all four vectors before anything is computed
#              (defaults to every observation).
#
# Targets and the rows they are evaluated on:
#   stent_or_cabg_010_day: `stent` on rows where `test` is set
#   test_010_day:          `test` on all rows
#   macetrop_030_pos:      `mace` on rows where `test` is not set
#
# Returns: a long data frame with columns `target_name`, `measure_name`
#   ("auc" or "r2") and `performance`; all auc rows come before the r2 rows.
measure_performance <- function(score, stent, test, mace,
                                which_obs = seq_along(score)) {
  # Evaluate the lazy default while `score` is still the full-length vector.
  force(which_obs)

  score <- score[which_obs]
  stent <- stent[which_obs]
  test <- test[which_obs]
  mace <- mace[which_obs]

  # `!test` already treats the flag as logical, but `x[test]` would read a 0/1
  # numeric flag as positions (0 dropped, 1 selecting the first element over and
  # over). Coercing once keeps both halves of the split consistent; for a flag
  # that is already logical this changes nothing.
  tested <- as.logical(test)

  score_by_target <- list(
    stent_or_cabg_010_day = score[tested],
    test_010_day = score,
    macetrop_030_pos = score[!tested]
  )
  outcome_by_target <- list(
    stent_or_cabg_010_day = stent[tested],
    test_010_day = test,
    macetrop_030_pos = mace[!tested]
  )

  # crossing() sorts its input, so the rows come out in alphabetical order of
  # target name.
  design <- crossing(
    target_name = c("stent_or_cabg_010_day", "test_010_day", "macetrop_030_pos")
  )

  # Attach the matching prediction/outcome pair to each target as list columns.
  performance_tb <- mutate(design,
    score = map(target_name, ~ score_by_target[[.x]]),
    target = map(target_name, ~ outcome_by_target[[.x]])
  )

  # One column per measure, then stacked into long format.
  performance_tb %>%
    transmute(
      target_name,
      auc = map2_dbl(target, score, auc),
      r2 = map2_dbl(target, score, calibrated_r2)
    ) %>%
    gather(
      key = "measure_name",
      value = "performance",
      auc, r2
    )
}

# Applies a metric function to a pair of vectors.
#
# Args:
#   measure:   a function accepting arguments named `actual` and `predicted`.
#   actual:    passed to `measure` as `actual`.
#   predicted: passed to `measure` as `predicted`.
#
# Returns: whatever `measure` returns.
take_measure <- function(measure, actual, predicted) {
  # Arguments are matched by name, so their order in the call is irrelevant.
  measure(predicted = predicted, actual = actual)
}

# R-squared of the least-squares regression of `actual` on `predicted` (with
# intercept), i.e. the fit after the predictions are linearly recalibrated.
#
# Args:
#   actual:    observed outcome values.
#   predicted: predicted values, same length as `actual`.
#
# Returns: a single number, the `r.squared` element of the model summary.
calibrated_r2 <- function(actual, predicted) {
  fit <- lm(actual ~ predicted)
  fit_summary <- summary(fit)
  fit_summary[["r.squared"]]
}

# Sum of squared differences between observed and predicted values.
#
# Args:
#   actual:    observed values.
#   predicted: predicted values.
#
# Returns: a single number.
resid_sum_sq <- function(actual, predicted) {
  residual <- actual - predicted
  sum(residual^2)
}

# Mean of `actual` among the observations with the highest predictions.
#
# Args:
#   actual:    observed outcome values.
#   predicted: predicted values, same length as `actual`.
#   top_share: fraction of observations to keep, counted from the highest
#              prediction downward.
#
# Returns: a single number; NaN when the share selects no observations.
ppv <- function(actual, predicted, top_share) {
  # Rank 1 is the highest prediction; ties are broken at random, so results
  # with tied predictions depend on the RNG state.
  descending_rank <- rank(-predicted, ties.method = "random")

  # `<=` so that the observation ranked exactly n * top_share is included. A
  # strict `<` drops it and selects nothing at all when n * top_share == 1.
  in_top_share <- descending_rank <= (length(predicted) * top_share)

  mean(actual[in_top_share])
}

# Load Data --------------------------------------------------------------------
message("Loading data...")
# Not used directly in the visible code; presumably a placeholder that glue()
# substitutes into the cohort path template below -- TODO confirm against
# lib/filepaths.yml.
overnight_lab <- ""
paths <- read_yaml(here::here("lib", "filepaths.yml"))
ids <- readRDS(glue(paths$analysis$test_cohort))

# Every file in the temp directory whose full path contains "scores_", read
# and row-bound into a single table.
score_files <- str_subset(list.files(temp, full.names = TRUE), "scores_")
score_tb <- score_files %>% map_df(readRDS)

# Exclusions -------------------------------------------------------------------
message("Exclusions...")
# Positions of the cohort rows to keep. Rows whose `exclude` is NA are dropped
# as well, because which() skips NA.
keep_obs <- which(!ids$exclude)
ids <- setDT(ids[keep_obs, ])

# Recode ptid as consecutive integers in order of first appearance.
ids[, ptid := match(ptid, unique(ptid))]

# Trim every score vector to the kept rows (assumes each vector is aligned
# row-for-row with the unfiltered cohort), then keep the GLM rows only.
score_tb <- score_tb %>%
  mutate(score = map(score, ~ .x[keep_obs])) %>%
  # filter(model_name == "lasso")
  filter(model_name == "glm")


# Bootstrap --------------------------------------------------------------------
message("Bootstrapping...")
unique_ptids <- unique(ids$ptid)

# Row positions of `ids` grouped by patient. The list is indexed by position
# with the sampled ptids below, which lines up only because ptid was recoded to
# consecutive integers 1..K in the exclusions step and split() orders its
# groups by increasing ptid.
id_row_dict <- split(seq_along(ids$ptid), ids$ptid)

# One resample per iteration: draw patients with replacement, then expand each
# draw to all of that patient's rows, so a patient's rows are resampled
# together (a patient drawn twice contributes their rows twice).
design_bootstrap <- tibble(
  iter = 1:1000,
  # map() replaces the deprecated purrr::rerun(). sample() is still called once
  # per iteration, in order, so the draws for a given seed are unchanged.
  ptid = map(iter, ~ sample(unique_ptids, replace = TRUE)),
  rows = map(ptid, ~ unlist(id_row_dict[.x]))
) %>%
  select(-ptid)

# Fit --------------------------------------------------------------------------
message("Measuring performance...")
# Point estimates: every score vector in score_tb is evaluated on the full
# post-exclusion cohort, and the per-target results are expanded into rows.
performance_tb <- score_tb %>%
  mutate(
    performance = map(score, function(model_score) {
      measure_performance(
        model_score,
        stent = ids$stent_or_cabg_010_day,
        test = ids$test_010_day,
        mace = ids$macetrop_030_pos
      )
    })
  ) %>%
  unnest(performance)

message("Bootstrapping performance...")
# Same evaluation repeated on each bootstrap resample. The score vectors are
# dropped inside the loop so they are not copied once per iteration.
bootstrap_tb <- design_bootstrap %>%
  mutate(
    performance_tb = map(rows, function(boot_rows) {
      resampled <- mutate(score_tb,
        performance = map(score, function(model_score) {
          measure_performance(
            model_score,
            stent = ids$stent_or_cabg_010_day,
            test = ids$test_010_day,
            mace = ids$macetrop_030_pos,
            which_obs = boot_rows
          )
        })
      )
      select(resampled, -score)
    })
  ) %>%
  select(-rows) %>%
  unnest(performance_tb) %>%
  unnest(performance)

# Save -------------------------------------------------------------------------
message("Saving...")
# Point estimates and bootstrap replicates are written next to the score files.
walk2(
  list(performance_tb, bootstrap_tb),
  c("performance_obs.rds", "performance_boot.rds"),
  ~ write_rds(.x, file.path(temp, .y))
)

# Done -------------------------------------------------------------------------
message("Done.")
